{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "_You are currently looking at **version 1.5** of this notebook. To download notebooks and datafiles, as well as get help on Jupyter notebooks in the Coursera platform, visit the [Jupyter Notebook FAQ](https://www.coursera.org/learn/python-machine-learning/resources/bANLa) course resource._\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assignment 2\n",
    "\n",
    "In this assignment you'll explore the relationship between model complexity and generalization performance, by adjusting key parameters of various supervised learning models. Part 1 of this assignment will look at regression and Part 2 will look at classification.\n",
    "\n",
    "## Part 1 - Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, run the following block to set up the variables needed for later sections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "np.random.seed(0)\n",
    "n = 15\n",
    "x = np.linspace(0,10,n) + np.random.randn(n)/5\n",
    "y = np.sin(x)+x/6 + np.random.randn(n)/10\n",
    "\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=0)\n",
    "\n",
    "# You can use this function to help you visualize the dataset by\n",
    "# plotting a scatterplot of the data points\n",
    "# in the training and test sets.\n",
    "def part1_scatter():\n",
    "    import matplotlib.pyplot as plt\n",
    "    %matplotlib notebook\n",
    "    plt.figure()\n",
    "    plt.scatter(X_train, y_train, label='training data')\n",
    "    plt.scatter(X_test, y_test, label='test data')\n",
    "    plt.legend(loc=4);\n",
    "    \n",
    "    \n",
    "# NOTE: Uncomment the function below to visualize the data, but be sure \n",
    "# to **re-comment it before submitting this assignment to the autograder**.   \n",
    "# part1_scatter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 1\n",
    "\n",
    "Write a function that fits a polynomial LinearRegression model on the *training data* `X_train` for degrees 1, 3, 6, and 9. (Use PolynomialFeatures in sklearn.preprocessing to create the polynomial features and then fit a linear regression model) For each model, find 100 predicted values over the interval x = 0 to 10 (e.g. `np.linspace(0,10,100)`) and store this in a numpy array. The first row of this array should correspond to the output from the model trained on degree 1, the second row degree 3, the third row degree 6, and the fourth row degree 9.\n",
    "\n",
    "<img src=\"readonly/polynomialreg1.png\" style=\"width: 1000px;\"/>\n",
    "\n",
    "The figure above shows the fitted models plotted on top of the original data (using `plot_one()`).\n",
    "\n",
    "<br>\n",
    "*This function should return a numpy array with shape `(4, 100)`*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  2.53040195e-01,   2.69201547e-01,   2.85362899e-01,\n",
       "          3.01524251e-01,   3.17685603e-01,   3.33846955e-01,\n",
       "          3.50008306e-01,   3.66169658e-01,   3.82331010e-01,\n",
       "          3.98492362e-01,   4.14653714e-01,   4.30815066e-01,\n",
       "          4.46976417e-01,   4.63137769e-01,   4.79299121e-01,\n",
       "          4.95460473e-01,   5.11621825e-01,   5.27783177e-01,\n",
       "          5.43944529e-01,   5.60105880e-01,   5.76267232e-01,\n",
       "          5.92428584e-01,   6.08589936e-01,   6.24751288e-01,\n",
       "          6.40912640e-01,   6.57073992e-01,   6.73235343e-01,\n",
       "          6.89396695e-01,   7.05558047e-01,   7.21719399e-01,\n",
       "          7.37880751e-01,   7.54042103e-01,   7.70203454e-01,\n",
       "          7.86364806e-01,   8.02526158e-01,   8.18687510e-01,\n",
       "          8.34848862e-01,   8.51010214e-01,   8.67171566e-01,\n",
       "          8.83332917e-01,   8.99494269e-01,   9.15655621e-01,\n",
       "          9.31816973e-01,   9.47978325e-01,   9.64139677e-01,\n",
       "          9.80301028e-01,   9.96462380e-01,   1.01262373e+00,\n",
       "          1.02878508e+00,   1.04494644e+00,   1.06110779e+00,\n",
       "          1.07726914e+00,   1.09343049e+00,   1.10959184e+00,\n",
       "          1.12575320e+00,   1.14191455e+00,   1.15807590e+00,\n",
       "          1.17423725e+00,   1.19039860e+00,   1.20655995e+00,\n",
       "          1.22272131e+00,   1.23888266e+00,   1.25504401e+00,\n",
       "          1.27120536e+00,   1.28736671e+00,   1.30352807e+00,\n",
       "          1.31968942e+00,   1.33585077e+00,   1.35201212e+00,\n",
       "          1.36817347e+00,   1.38433482e+00,   1.40049618e+00,\n",
       "          1.41665753e+00,   1.43281888e+00,   1.44898023e+00,\n",
       "          1.46514158e+00,   1.48130294e+00,   1.49746429e+00,\n",
       "          1.51362564e+00,   1.52978699e+00,   1.54594834e+00,\n",
       "          1.56210969e+00,   1.57827105e+00,   1.59443240e+00,\n",
       "          1.61059375e+00,   1.62675510e+00,   1.64291645e+00,\n",
       "          1.65907781e+00,   1.67523916e+00,   1.69140051e+00,\n",
       "          1.70756186e+00,   1.72372321e+00,   1.73988457e+00,\n",
       "          1.75604592e+00,   1.77220727e+00,   1.78836862e+00,\n",
       "          1.80452997e+00,   1.82069132e+00,   1.83685268e+00,\n",
       "          1.85301403e+00],\n",
       "       [  1.22989539e+00,   1.15143628e+00,   1.07722393e+00,\n",
       "          1.00717881e+00,   9.41221419e-01,   8.79272234e-01,\n",
       "          8.21251741e-01,   7.67080426e-01,   7.16678772e-01,\n",
       "          6.69967266e-01,   6.26866391e-01,   5.87296632e-01,\n",
       "          5.51178474e-01,   5.18432402e-01,   4.88978901e-01,\n",
       "          4.62738455e-01,   4.39631549e-01,   4.19578668e-01,\n",
       "          4.02500297e-01,   3.88316920e-01,   3.76949022e-01,\n",
       "          3.68317088e-01,   3.62341603e-01,   3.58943051e-01,\n",
       "          3.58041918e-01,   3.59558687e-01,   3.63413845e-01,\n",
       "          3.69527874e-01,   3.77821261e-01,   3.88214491e-01,\n",
       "          4.00628046e-01,   4.14982414e-01,   4.31198078e-01,\n",
       "          4.49195522e-01,   4.68895233e-01,   4.90217694e-01,\n",
       "          5.13083391e-01,   5.37412808e-01,   5.63126429e-01,\n",
       "          5.90144741e-01,   6.18388226e-01,   6.47777371e-01,\n",
       "          6.78232660e-01,   7.09674578e-01,   7.42023609e-01,\n",
       "          7.75200238e-01,   8.09124950e-01,   8.43718230e-01,\n",
       "          8.78900563e-01,   9.14592432e-01,   9.50714324e-01,\n",
       "          9.87186723e-01,   1.02393011e+00,   1.06086498e+00,\n",
       "          1.09791181e+00,   1.13499108e+00,   1.17202328e+00,\n",
       "          1.20892890e+00,   1.24562842e+00,   1.28204233e+00,\n",
       "          1.31809110e+00,   1.35369523e+00,   1.38877520e+00,\n",
       "          1.42325149e+00,   1.45704459e+00,   1.49007498e+00,\n",
       "          1.52226316e+00,   1.55352959e+00,   1.58379478e+00,\n",
       "          1.61297919e+00,   1.64100332e+00,   1.66778766e+00,\n",
       "          1.69325268e+00,   1.71731887e+00,   1.73990672e+00,\n",
       "          1.76093671e+00,   1.78032933e+00,   1.79800506e+00,\n",
       "          1.81388438e+00,   1.82788778e+00,   1.83993575e+00,\n",
       "          1.84994877e+00,   1.85784732e+00,   1.86355189e+00,\n",
       "          1.86698296e+00,   1.86806103e+00,   1.86670656e+00,\n",
       "          1.86284006e+00,   1.85638200e+00,   1.84725286e+00,\n",
       "          1.83537314e+00,   1.82066332e+00,   1.80304388e+00,\n",
       "          1.78243530e+00,   1.75875808e+00,   1.73193269e+00,\n",
       "          1.70187963e+00,   1.66851936e+00,   1.63177240e+00,\n",
       "          1.59155920e+00],\n",
       "       [ -1.99554310e-01,  -3.95192724e-03,   1.79851752e-01,\n",
       "          3.51005136e-01,   5.08831706e-01,   6.52819233e-01,\n",
       "          7.82609240e-01,   8.97986721e-01,   9.98870117e-01,\n",
       "          1.08530155e+00,   1.15743729e+00,   1.21553852e+00,\n",
       "          1.25996233e+00,   1.29115292e+00,   1.30963316e+00,\n",
       "          1.31599632e+00,   1.31089811e+00,   1.29504889e+00,\n",
       "          1.26920626e+00,   1.23416782e+00,   1.19076415e+00,\n",
       "          1.13985218e+00,   1.08230867e+00,   1.01902405e+00,\n",
       "          9.50896441e-01,   8.78825970e-01,   8.03709344e-01,\n",
       "          7.26434655e-01,   6.47876457e-01,   5.68891088e-01,\n",
       "          4.90312256e-01,   4.12946874e-01,   3.37571147e-01,\n",
       "          2.64926923e-01,   1.95718291e-01,   1.30608438e-01,\n",
       "          7.02167560e-02,   1.51162118e-02,  -3.41690366e-02,\n",
       "         -7.71657636e-02,  -1.13453547e-01,  -1.42666382e-01,\n",
       "         -1.64494044e-01,  -1.78683194e-01,  -1.85038228e-01,\n",
       "         -1.83421873e-01,  -1.73755533e-01,  -1.56019368e-01,\n",
       "         -1.30252132e-01,  -9.65507462e-02,  -5.50696232e-02,\n",
       "         -6.01973198e-03,   5.03325883e-02,   1.13667071e-01,\n",
       "          1.83611221e-01,   2.59742264e-01,   3.41589357e-01,\n",
       "          4.28636046e-01,   5.20322987e-01,   6.16050916e-01,\n",
       "          7.15183874e-01,   8.17052690e-01,   9.20958717e-01,\n",
       "          1.02617782e+00,   1.13196463e+00,   1.23755703e+00,\n",
       "          1.34218093e+00,   1.44505526e+00,   1.54539723e+00,\n",
       "          1.64242789e+00,   1.73537785e+00,   1.82349336e+00,\n",
       "          1.90604254e+00,   1.98232198e+00,   2.05166348e+00,\n",
       "          2.11344114e+00,   2.16707864e+00,   2.21205680e+00,\n",
       "          2.24792141e+00,   2.27429129e+00,   2.29086658e+00,\n",
       "          2.29743739e+00,   2.29389257e+00,   2.28022881e+00,\n",
       "          2.25656001e+00,   2.22312684e+00,   2.18030664e+00,\n",
       "          2.12862347e+00,   2.06875850e+00,   2.00156065e+00,\n",
       "          1.92805743e+00,   1.84946605e+00,   1.76720485e+00,\n",
       "          1.68290491e+00,   1.59842194e+00,   1.51584842e+00,\n",
       "          1.43752602e+00,   1.36605824e+00,   1.30432333e+00,\n",
       "          1.25548743e+00],\n",
       "       [  6.79502285e+00,   4.14319957e+00,   2.23123322e+00,\n",
       "          9.10495532e-01,   5.49803315e-02,  -4.41344457e-01,\n",
       "         -6.66950444e-01,  -6.94942887e-01,  -5.85049614e-01,\n",
       "         -3.85418417e-01,  -1.34236065e-01,   1.38818559e-01,\n",
       "          4.11275202e-01,   6.66715442e-01,   8.93747460e-01,\n",
       "          1.08510202e+00,   1.23683979e+00,   1.34766069e+00,\n",
       "          1.41830632e+00,   1.45104724e+00,   1.44924694e+00,\n",
       "          1.41699534e+00,   1.35880444e+00,   1.27935985e+00,\n",
       "          1.18332182e+00,   1.07516995e+00,   9.59086410e-01,\n",
       "          8.38872457e-01,   7.17893658e-01,   5.99049596e-01,\n",
       "          4.84764051e-01,   3.76992063e-01,   2.77240599e-01,\n",
       "          1.86599822e-01,   1.05782272e-01,   3.51675757e-02,\n",
       "         -2.51494865e-02,  -7.53094019e-02,  -1.15638484e-01,\n",
       "         -1.46600958e-01,  -1.68753745e-01,  -1.82704910e-01,\n",
       "         -1.89076542e-01,  -1.88472636e-01,  -1.81452388e-01,\n",
       "         -1.68509141e-01,  -1.50055083e-01,  -1.26411638e-01,\n",
       "         -9.78053923e-02,  -6.43692604e-02,  -2.61485139e-02,\n",
       "          1.68888091e-02,   6.48376626e-02,   1.17838541e-01,\n",
       "          1.76057485e-01,   2.39664260e-01,   3.08809443e-01,\n",
       "          3.83601186e-01,   4.64082407e-01,   5.50209170e-01,\n",
       "          6.41830991e-01,   7.38673768e-01,   8.40326006e-01,\n",
       "          9.46228923e-01,   1.05567100e+00,   1.16778742e+00,\n",
       "          1.28156471e+00,   1.39585100e+00,   1.50937183e+00,\n",
       "          1.62075165e+00,   1.72854097e+00,   1.83124862e+00,\n",
       "          1.92737898e+00,   2.01547331e+00,   2.09415458e+00,\n",
       "          2.16217465e+00,   2.21846257e+00,   2.26217273e+00,\n",
       "          2.29273094e+00,   2.30987668e+00,   2.31369926e+00,\n",
       "          2.30466539e+00,   2.28363551e+00,   2.25186569e+00,\n",
       "          2.21099186e+00,   2.16299265e+00,   2.11012671e+00,\n",
       "          2.05484041e+00,   1.99964089e+00,   1.94692956e+00,\n",
       "          1.89879060e+00,   1.85672836e+00,   1.82134774e+00,\n",
       "          1.79197049e+00,   1.76618058e+00,   1.73929091e+00,\n",
       "          1.70372341e+00,   1.64829405e+00,   1.55739372e+00,\n",
       "          1.41005558e+00]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_one():\n",
    "    from sklearn.linear_model import LinearRegression\n",
    "    from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "    # Your code here\n",
    "    result = np.zeros((4, 100))\n",
    "    degrees = [1, 3, 6, 9]\n",
    "    for i, degree in enumerate(degrees):\n",
    "        poly = PolynomialFeatures(degree = degree)\n",
    "        X_train_poly = poly.fit_transform(X_train.reshape(11, 1))\n",
    "        linreg = LinearRegression().fit(X_train_poly, y_train)\n",
    "        prediction = linreg.predict(poly.fit_transform(np.linspace(0, 10, 100).reshape(100, 1)))\n",
    "        result[i, :] = prediction\n",
    "    \n",
    "    return result\n",
    "\n",
    "answer_one()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "/* Put everything inside the global mpl namespace */\n",
       "window.mpl = {};\n",
       "\n",
       "\n",
       "mpl.get_websocket_type = function() {\n",
       "    if (typeof(WebSocket) !== 'undefined') {\n",
       "        return WebSocket;\n",
       "    } else if (typeof(MozWebSocket) !== 'undefined') {\n",
       "        return MozWebSocket;\n",
       "    } else {\n",
       "        alert('Your browser does not have WebSocket support.' +\n",
       "              'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
       "              'Firefox 4 and 5 are also supported but you ' +\n",
       "              'have to enable WebSockets in about:config.');\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
       "    this.id = figure_id;\n",
       "\n",
       "    this.ws = websocket;\n",
       "\n",
       "    this.supports_binary = (this.ws.binaryType != undefined);\n",
       "\n",
       "    if (!this.supports_binary) {\n",
       "        var warnings = document.getElementById(\"mpl-warnings\");\n",
       "        if (warnings) {\n",
       "            warnings.style.display = 'block';\n",
       "            warnings.textContent = (\n",
       "                \"This browser does not support binary websocket messages. \" +\n",
       "                    \"Performance may be slow.\");\n",
       "        }\n",
       "    }\n",
       "\n",
       "    this.imageObj = new Image();\n",
       "\n",
       "    this.context = undefined;\n",
       "    this.message = undefined;\n",
       "    this.canvas = undefined;\n",
       "    this.rubberband_canvas = undefined;\n",
       "    this.rubberband_context = undefined;\n",
       "    this.format_dropdown = undefined;\n",
       "\n",
       "    this.image_mode = 'full';\n",
       "\n",
       "    this.root = $('<div/>');\n",
       "    this._root_extra_style(this.root)\n",
       "    this.root.attr('style', 'display: inline-block');\n",
       "\n",
       "    $(parent_element).append(this.root);\n",
       "\n",
       "    this._init_header(this);\n",
       "    this._init_canvas(this);\n",
       "    this._init_toolbar(this);\n",
       "\n",
       "    var fig = this;\n",
       "\n",
       "    this.waiting = false;\n",
       "\n",
       "    this.ws.onopen =  function () {\n",
       "            fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
       "            fig.send_message(\"send_image_mode\", {});\n",
       "            if (mpl.ratio != 1) {\n",
       "                fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
       "            }\n",
       "            fig.send_message(\"refresh\", {});\n",
       "        }\n",
       "\n",
       "    this.imageObj.onload = function() {\n",
       "            if (fig.image_mode == 'full') {\n",
       "                // Full images could contain transparency (where diff images\n",
       "                // almost always do), so we need to clear the canvas so that\n",
       "                // there is no ghosting.\n",
       "                fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
       "            }\n",
       "            fig.context.drawImage(fig.imageObj, 0, 0);\n",
       "        };\n",
       "\n",
       "    this.imageObj.onunload = function() {\n",
       "        this.ws.close();\n",
       "    }\n",
       "\n",
       "    this.ws.onmessage = this._make_on_message_function(this);\n",
       "\n",
       "    this.ondownload = ondownload;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_header = function() {\n",
       "    var titlebar = $(\n",
       "        '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
       "        'ui-helper-clearfix\"/>');\n",
       "    var titletext = $(\n",
       "        '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
       "        'text-align: center; padding: 3px;\"/>');\n",
       "    titlebar.append(titletext)\n",
       "    this.root.append(titlebar);\n",
       "    this.header = titletext[0];\n",
       "}\n",
       "\n",
       "\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_canvas = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var canvas_div = $('<div/>');\n",
       "\n",
       "    canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
       "\n",
       "    function canvas_keyboard_event(event) {\n",
       "        return fig.key_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    canvas_div.keydown('key_press', canvas_keyboard_event);\n",
       "    canvas_div.keyup('key_release', canvas_keyboard_event);\n",
       "    this.canvas_div = canvas_div\n",
       "    this._canvas_extra_style(canvas_div)\n",
       "    this.root.append(canvas_div);\n",
       "\n",
       "    var canvas = $('<canvas/>');\n",
       "    canvas.addClass('mpl-canvas');\n",
       "    canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
       "\n",
       "    this.canvas = canvas[0];\n",
       "    this.context = canvas[0].getContext(\"2d\");\n",
       "\n",
       "    var backingStore = this.context.backingStorePixelRatio ||\n",
       "\tthis.context.webkitBackingStorePixelRatio ||\n",
       "\tthis.context.mozBackingStorePixelRatio ||\n",
       "\tthis.context.msBackingStorePixelRatio ||\n",
       "\tthis.context.oBackingStorePixelRatio ||\n",
       "\tthis.context.backingStorePixelRatio || 1;\n",
       "\n",
       "    mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
       "\n",
       "    var rubberband = $('<canvas/>');\n",
       "    rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
       "\n",
       "    var pass_mouse_events = true;\n",
       "\n",
       "    canvas_div.resizable({\n",
       "        start: function(event, ui) {\n",
       "            pass_mouse_events = false;\n",
       "        },\n",
       "        resize: function(event, ui) {\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "        stop: function(event, ui) {\n",
       "            pass_mouse_events = true;\n",
       "            fig.request_resize(ui.size.width, ui.size.height);\n",
       "        },\n",
       "    });\n",
       "\n",
       "    function mouse_event_fn(event) {\n",
       "        if (pass_mouse_events)\n",
       "            return fig.mouse_event(event, event['data']);\n",
       "    }\n",
       "\n",
       "    rubberband.mousedown('button_press', mouse_event_fn);\n",
       "    rubberband.mouseup('button_release', mouse_event_fn);\n",
       "    // Throttle sequential mouse events to 1 every 20ms.\n",
       "    rubberband.mousemove('motion_notify', mouse_event_fn);\n",
       "\n",
       "    rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
       "    rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
       "\n",
       "    canvas_div.on(\"wheel\", function (event) {\n",
       "        event = event.originalEvent;\n",
       "        event['data'] = 'scroll'\n",
       "        if (event.deltaY < 0) {\n",
       "            event.step = 1;\n",
       "        } else {\n",
       "            event.step = -1;\n",
       "        }\n",
       "        mouse_event_fn(event);\n",
       "    });\n",
       "\n",
       "    canvas_div.append(canvas);\n",
       "    canvas_div.append(rubberband);\n",
       "\n",
       "    this.rubberband = rubberband;\n",
       "    this.rubberband_canvas = rubberband[0];\n",
       "    this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
       "    this.rubberband_context.strokeStyle = \"#000000\";\n",
       "\n",
       "    this._resize_canvas = function(width, height) {\n",
       "        // Keep the size of the canvas, canvas container, and rubber band\n",
       "        // canvas in synch.\n",
       "        canvas_div.css('width', width)\n",
       "        canvas_div.css('height', height)\n",
       "\n",
       "        canvas.attr('width', width * mpl.ratio);\n",
       "        canvas.attr('height', height * mpl.ratio);\n",
       "        canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
       "\n",
       "        rubberband.attr('width', width);\n",
       "        rubberband.attr('height', height);\n",
       "    }\n",
       "\n",
       "    // Set the figure to an initial 600x600px, this will subsequently be updated\n",
       "    // upon first draw.\n",
       "    this._resize_canvas(600, 600);\n",
       "\n",
       "    // Disable right mouse context menu.\n",
       "    $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
       "        return false;\n",
       "    });\n",
       "\n",
       "    function set_focus () {\n",
       "        canvas.focus();\n",
       "        canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    window.setTimeout(set_focus, 100);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items) {\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) {\n",
       "            // put a spacer in here.\n",
       "            continue;\n",
       "        }\n",
       "        var button = $('<button/>');\n",
       "        button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
       "                        'ui-button-icon-only');\n",
       "        button.attr('role', 'button');\n",
       "        button.attr('aria-disabled', 'false');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "\n",
       "        var icon_img = $('<span/>');\n",
       "        icon_img.addClass('ui-button-icon-primary ui-icon');\n",
       "        icon_img.addClass(image);\n",
       "        icon_img.addClass('ui-corner-all');\n",
       "\n",
       "        var tooltip_span = $('<span/>');\n",
       "        tooltip_span.addClass('ui-button-text');\n",
       "        tooltip_span.html(tooltip);\n",
       "\n",
       "        button.append(icon_img);\n",
       "        button.append(tooltip_span);\n",
       "\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    var fmt_picker_span = $('<span/>');\n",
       "\n",
       "    var fmt_picker = $('<select/>');\n",
       "    fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
       "    fmt_picker_span.append(fmt_picker);\n",
       "    nav_element.append(fmt_picker_span);\n",
       "    this.format_dropdown = fmt_picker[0];\n",
       "\n",
       "    for (var ind in mpl.extensions) {\n",
       "        var fmt = mpl.extensions[ind];\n",
       "        var option = $(\n",
       "            '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
       "        fmt_picker.append(option)\n",
       "    }\n",
       "\n",
       "    // Add hover states to the ui-buttons\n",
       "    $( \".ui-button\" ).hover(\n",
       "        function() { $(this).addClass(\"ui-state-hover\");},\n",
       "        function() { $(this).removeClass(\"ui-state-hover\");}\n",
       "    );\n",
       "\n",
       "    var status_bar = $('<span class=\"mpl-message\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
       "    // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
       "    // which will in turn request a refresh of the image.\n",
       "    this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_message = function(type, properties) {\n",
       "    properties['type'] = type;\n",
       "    properties['figure_id'] = this.id;\n",
       "    this.ws.send(JSON.stringify(properties));\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.send_draw_message = function() {\n",
       "    if (!this.waiting) {\n",
       "        this.waiting = true;\n",
       "        this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
       "    }\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    var format_dropdown = fig.format_dropdown;\n",
       "    var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
       "    fig.ondownload(fig, format);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
       "    var size = msg['size'];\n",
       "    if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
       "        fig._resize_canvas(size[0], size[1]);\n",
       "        fig.send_message(\"refresh\", {});\n",
       "    };\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
       "    var x0 = msg['x0'] / mpl.ratio;\n",
       "    var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
       "    var x1 = msg['x1'] / mpl.ratio;\n",
       "    var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
       "    x0 = Math.floor(x0) + 0.5;\n",
       "    y0 = Math.floor(y0) + 0.5;\n",
       "    x1 = Math.floor(x1) + 0.5;\n",
       "    y1 = Math.floor(y1) + 0.5;\n",
       "    var min_x = Math.min(x0, x1);\n",
       "    var min_y = Math.min(y0, y1);\n",
       "    var width = Math.abs(x1 - x0);\n",
       "    var height = Math.abs(y1 - y0);\n",
       "\n",
       "    fig.rubberband_context.clearRect(\n",
       "        0, 0, fig.canvas.width, fig.canvas.height);\n",
       "\n",
       "    fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
       "    // Updates the figure title.\n",
       "    fig.header.textContent = msg['label'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
       "    var cursor = msg['cursor'];\n",
       "    switch(cursor)\n",
       "    {\n",
       "    case 0:\n",
       "        cursor = 'pointer';\n",
       "        break;\n",
       "    case 1:\n",
       "        cursor = 'default';\n",
       "        break;\n",
       "    case 2:\n",
       "        cursor = 'crosshair';\n",
       "        break;\n",
       "    case 3:\n",
       "        cursor = 'move';\n",
       "        break;\n",
       "    }\n",
       "    fig.rubberband_canvas.style.cursor = cursor;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
       "    fig.message.textContent = msg['message'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
       "    // Request the server to send over a new figure.\n",
       "    fig.send_draw_message();\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
       "    fig.image_mode = msg['mode'];\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Called whenever the canvas gets updated.\n",
       "    this.send_message(\"ack\", {});\n",
       "}\n",
       "\n",
       "// A function to construct a web socket function for onmessage handling.\n",
       "// Called in the figure constructor.\n",
       "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
       "    return function socket_on_message(evt) {\n",
       "        if (evt.data instanceof Blob) {\n",
       "            /* FIXME: We get \"Resource interpreted as Image but\n",
       "             * transferred with MIME type text/plain:\" errors on\n",
       "             * Chrome.  But how to set the MIME type?  It doesn't seem\n",
       "             * to be part of the websocket stream */\n",
       "            evt.data.type = \"image/png\";\n",
       "\n",
       "            /* Free the memory for the previous frames */\n",
       "            if (fig.imageObj.src) {\n",
       "                (window.URL || window.webkitURL).revokeObjectURL(\n",
       "                    fig.imageObj.src);\n",
       "            }\n",
       "\n",
       "            fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
       "                evt.data);\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "        else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
       "            fig.imageObj.src = evt.data;\n",
       "            fig.updated_canvas_event();\n",
       "            fig.waiting = false;\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        var msg = JSON.parse(evt.data);\n",
       "        var msg_type = msg['type'];\n",
       "\n",
       "        // Call the  \"handle_{type}\" callback, which takes\n",
       "        // the figure and JSON message as its only arguments.\n",
       "        try {\n",
       "            var callback = fig[\"handle_\" + msg_type];\n",
       "        } catch (e) {\n",
       "            console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
       "            return;\n",
       "        }\n",
       "\n",
       "        if (callback) {\n",
       "            try {\n",
       "                // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
       "                callback(fig, msg);\n",
       "            } catch (e) {\n",
       "                console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
       "            }\n",
       "        }\n",
       "    };\n",
       "}\n",
       "\n",
       "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
       "mpl.findpos = function(e) {\n",
       "    //this section is from http://www.quirksmode.org/js/events_properties.html\n",
       "    var targ;\n",
       "    if (!e)\n",
       "        e = window.event;\n",
       "    if (e.target)\n",
       "        targ = e.target;\n",
       "    else if (e.srcElement)\n",
       "        targ = e.srcElement;\n",
       "    if (targ.nodeType == 3) // defeat Safari bug\n",
       "        targ = targ.parentNode;\n",
       "\n",
       "    // jQuery normalizes the pageX and pageY\n",
       "    // pageX,Y are the mouse positions relative to the document\n",
       "    // offset() returns the position of the element relative to the document\n",
       "    var x = e.pageX - $(targ).offset().left;\n",
       "    var y = e.pageY - $(targ).offset().top;\n",
       "\n",
       "    return {\"x\": x, \"y\": y};\n",
       "};\n",
       "\n",
       "/*\n",
       " * return a copy of an object with only non-object keys\n",
       " * we need this to avoid circular references\n",
       " * http://stackoverflow.com/a/24161582/3208463\n",
       " */\n",
       "function simpleKeys (original) {\n",
       "  return Object.keys(original).reduce(function (obj, key) {\n",
       "    if (typeof original[key] !== 'object')\n",
       "        obj[key] = original[key]\n",
       "    return obj;\n",
       "  }, {});\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.mouse_event = function(event, name) {\n",
       "    var canvas_pos = mpl.findpos(event)\n",
       "\n",
       "    if (name === 'button_press')\n",
       "    {\n",
       "        this.canvas.focus();\n",
       "        this.canvas_div.focus();\n",
       "    }\n",
       "\n",
       "    var x = canvas_pos.x * mpl.ratio;\n",
       "    var y = canvas_pos.y * mpl.ratio;\n",
       "\n",
       "    this.send_message(name, {x: x, y: y, button: event.button,\n",
       "                             step: event.step,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "\n",
       "    /* This prevents the web browser from automatically changing to\n",
       "     * the text insertion cursor when the button is pressed.  We want\n",
       "     * to control all of the cursor setting manually through the\n",
       "     * 'cursor' event from matplotlib */\n",
       "    event.preventDefault();\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    // Handle any extra behaviour associated with a key event\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.key_event = function(event, name) {\n",
       "\n",
       "    // Prevent repeat events\n",
       "    if (name == 'key_press')\n",
       "    {\n",
       "        if (event.which === this._key)\n",
       "            return;\n",
       "        else\n",
       "            this._key = event.which;\n",
       "    }\n",
       "    if (name == 'key_release')\n",
       "        this._key = null;\n",
       "\n",
       "    var value = '';\n",
       "    if (event.ctrlKey && event.which != 17)\n",
       "        value += \"ctrl+\";\n",
       "    if (event.altKey && event.which != 18)\n",
       "        value += \"alt+\";\n",
       "    if (event.shiftKey && event.which != 16)\n",
       "        value += \"shift+\";\n",
       "\n",
       "    value += 'k';\n",
       "    value += event.which.toString();\n",
       "\n",
       "    this._key_event_extra(event, name);\n",
       "\n",
       "    this.send_message(name, {key: value,\n",
       "                             guiEvent: simpleKeys(event)});\n",
       "    return false;\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
       "    if (name == 'download') {\n",
       "        this.handle_save(this, null);\n",
       "    } else {\n",
       "        this.send_message(\"toolbar_button\", {name: name});\n",
       "    }\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
       "    this.message.textContent = tooltip;\n",
       "};\n",
       "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to  previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
       "\n",
       "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
       "\n",
       "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
       "    // Create a \"websocket\"-like object which calls the given IPython comm\n",
       "    // object with the appropriate methods. Currently this is a non binary\n",
       "    // socket, so there is still some room for performance tuning.\n",
       "    var ws = {};\n",
       "\n",
       "    ws.close = function() {\n",
       "        comm.close()\n",
       "    };\n",
       "    ws.send = function(m) {\n",
       "        //console.log('sending', m);\n",
       "        comm.send(m);\n",
       "    };\n",
       "    // Register the callback with on_msg.\n",
       "    comm.on_msg(function(msg) {\n",
       "        //console.log('receiving', msg['content']['data'], msg);\n",
       "        // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
       "        ws.onmessage(msg['content']['data'])\n",
       "    });\n",
       "    return ws;\n",
       "}\n",
       "\n",
       "mpl.mpl_figure_comm = function(comm, msg) {\n",
       "    // This is the function which gets called when the mpl process\n",
       "    // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
       "\n",
       "    var id = msg.content.data.id;\n",
       "    // Get hold of the div created by the display call when the Comm\n",
       "    // socket was opened in Python.\n",
       "    var element = $(\"#\" + id);\n",
       "    var ws_proxy = comm_websocket_adapter(comm)\n",
       "\n",
       "    function ondownload(figure, format) {\n",
       "        window.open(figure.imageObj.src);\n",
       "    }\n",
       "\n",
       "    var fig = new mpl.figure(id, ws_proxy,\n",
       "                           ondownload,\n",
       "                           element.get(0));\n",
       "\n",
       "    // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
       "    // web socket which is closed, not our websocket->open comm proxy.\n",
       "    ws_proxy.onopen();\n",
       "\n",
       "    fig.parent_element = element.get(0);\n",
       "    fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
       "    if (!fig.cell_info) {\n",
       "        console.error(\"Failed to find cell for figure\", id, fig);\n",
       "        return;\n",
       "    }\n",
       "\n",
       "    var output_index = fig.cell_info[2]\n",
       "    var cell = fig.cell_info[0];\n",
       "\n",
       "};\n",
       "\n",
       "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
       "    var width = fig.canvas.width/mpl.ratio\n",
       "    fig.root.unbind('remove')\n",
       "\n",
       "    // Update the output cell to use the data from the current canvas.\n",
       "    fig.push_to_output();\n",
       "    var dataURL = fig.canvas.toDataURL();\n",
       "    // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
       "    // the notebook keyboard shortcuts fail.\n",
       "    IPython.keyboard_manager.enable()\n",
       "    $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
       "    fig.close_ws(fig, msg);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.close_ws = function(fig, msg){\n",
       "    fig.send_message('closing', msg);\n",
       "    // fig.ws.close()\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
       "    // Turn the data on the canvas into data in the output cell.\n",
       "    var width = this.canvas.width/mpl.ratio\n",
       "    var dataURL = this.canvas.toDataURL();\n",
       "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.updated_canvas_event = function() {\n",
       "    // Tell IPython that the notebook contents must change.\n",
       "    IPython.notebook.set_dirty(true);\n",
       "    this.send_message(\"ack\", {});\n",
       "    var fig = this;\n",
       "    // Wait a second, then push the new image to the DOM so\n",
       "    // that it is saved nicely (might be nice to debounce this).\n",
       "    setTimeout(function () { fig.push_to_output() }, 1000);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._init_toolbar = function() {\n",
       "    var fig = this;\n",
       "\n",
       "    var nav_element = $('<div/>')\n",
       "    nav_element.attr('style', 'width: 100%');\n",
       "    this.root.append(nav_element);\n",
       "\n",
       "    // Define a callback function for later on.\n",
       "    function toolbar_event(event) {\n",
       "        return fig.toolbar_button_onclick(event['data']);\n",
       "    }\n",
       "    function toolbar_mouse_event(event) {\n",
       "        return fig.toolbar_button_onmouseover(event['data']);\n",
       "    }\n",
       "\n",
       "    for(var toolbar_ind in mpl.toolbar_items){\n",
       "        var name = mpl.toolbar_items[toolbar_ind][0];\n",
       "        var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
       "        var image = mpl.toolbar_items[toolbar_ind][2];\n",
       "        var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
       "\n",
       "        if (!name) { continue; };\n",
       "\n",
       "        var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
       "        button.click(method_name, toolbar_event);\n",
       "        button.mouseover(tooltip, toolbar_mouse_event);\n",
       "        nav_element.append(button);\n",
       "    }\n",
       "\n",
       "    // Add the status bar.\n",
       "    var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
       "    nav_element.append(status_bar);\n",
       "    this.message = status_bar[0];\n",
       "\n",
       "    // Add the close button to the window.\n",
       "    var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
       "    var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
       "    button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
       "    button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
       "    buttongrp.append(button);\n",
       "    var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
       "    titlebar.prepend(buttongrp);\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._root_extra_style = function(el){\n",
       "    var fig = this\n",
       "    el.on(\"remove\", function(){\n",
       "\tfig.close_ws(fig, {});\n",
       "    });\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._canvas_extra_style = function(el){\n",
       "    // this is important to make the div 'focusable\n",
       "    el.attr('tabindex', 0)\n",
       "    // reach out to IPython and tell the keyboard manager to turn it's self\n",
       "    // off when our div gets focus\n",
       "\n",
       "    // location in version 3\n",
       "    if (IPython.notebook.keyboard_manager) {\n",
       "        IPython.notebook.keyboard_manager.register_events(el);\n",
       "    }\n",
       "    else {\n",
       "        // location in version 2\n",
       "        IPython.keyboard_manager.register_events(el);\n",
       "    }\n",
       "\n",
       "}\n",
       "\n",
       "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
       "    var manager = IPython.notebook.keyboard_manager;\n",
       "    if (!manager)\n",
       "        manager = IPython.keyboard_manager;\n",
       "\n",
       "    // Check for shift+enter\n",
       "    if (event.shiftKey && event.which == 13) {\n",
       "        this.canvas_div.blur();\n",
       "        // select the cell after this one\n",
       "        var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
       "        IPython.notebook.select(index + 1);\n",
       "    }\n",
       "}\n",
       "\n",
       "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
       "    fig.ondownload(fig, null);\n",
       "}\n",
       "\n",
       "\n",
       "mpl.find_output_cell = function(html_output) {\n",
       "    // Return the cell and output element which can be found *uniquely* in the notebook.\n",
       "    // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
       "    // IPython event is triggered only after the cells have been serialised, which for\n",
       "    // our purposes (turning an active figure into a static one), is too late.\n",
       "    var cells = IPython.notebook.get_cells();\n",
       "    var ncells = cells.length;\n",
       "    for (var i=0; i<ncells; i++) {\n",
       "        var cell = cells[i];\n",
       "        if (cell.cell_type === 'code'){\n",
       "            for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
       "                var data = cell.output_area.outputs[j];\n",
       "                if (data.data) {\n",
       "                    // IPython >= 3 moved mimebundle to data attribute of output\n",
       "                    data = data.data;\n",
       "                }\n",
       "                if (data['text/html'] == html_output) {\n",
       "                    return [cell, data, j];\n",
       "                }\n",
       "            }\n",
       "        }\n",
       "    }\n",
       "}\n",
       "\n",
       "// Register the function which deals with the matplotlib target/channel.\n",
       "// The kernel may be null if the page has been refreshed.\n",
       "if (IPython.notebook.kernel != null) {\n",
       "    IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
       "}\n"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"\" width=\"1000\">"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# feel free to use the function plot_one() to replicate the figure \n",
    "# from the prompt once you have completed question one\n",
    "def plot_one(degree_predictions):\n",
    "    import matplotlib.pyplot as plt\n",
    "    %matplotlib notebook\n",
    "    plt.figure(figsize=(10,5))\n",
    "    plt.plot(X_train, y_train, 'o', label='training data', markersize=10)\n",
    "    plt.plot(X_test, y_test, 'o', label='test data', markersize=10)\n",
    "    for i,degree in enumerate([1,3,6,9]):\n",
    "        plt.plot(np.linspace(0,10,100), degree_predictions[i], alpha=0.8, lw=2, label='degree={}'.format(degree))\n",
    "    plt.ylim(-1,2.5)\n",
    "    plt.legend(loc=4)\n",
    "\n",
    "# plot_one(answer_one())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 2\n",
    "\n",
    "Write a function that fits a polynomial LinearRegression model on the training data `X_train` for degrees 0 through 9. For each model compute the $R^2$ (coefficient of determination) regression score on the training data as well as the the test data, and return both of these arrays in a tuple.\n",
    "\n",
    "*This function should return one tuple of numpy arrays `(r2_train, r2_test)`. Both arrays should have shape `(10,)`*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([0.0,\n",
       "  0.42924577812346632,\n",
       "  0.4510998044408247,\n",
       "  0.58719953687798465,\n",
       "  0.91941944717693436,\n",
       "  0.97578641430682345,\n",
       "  0.99018233247950815,\n",
       "  0.99352509278403633,\n",
       "  0.99637545387765036,\n",
       "  0.99803706256653058],\n",
       " [-0.47808641737141788,\n",
       "  -0.45237104233936676,\n",
       "  -0.068569841499158901,\n",
       "  0.0053310529457363254,\n",
       "  0.73004942818707141,\n",
       "  0.87708300916999382,\n",
       "  0.92140939814150025,\n",
       "  0.92021504127165776,\n",
       "  0.63247950833243671,\n",
       "  -0.64525377099038894])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_two():\n",
    "    from sklearn.linear_model import LinearRegression\n",
    "    from sklearn.preprocessing import PolynomialFeatures\n",
    "    from sklearn.metrics.regression import r2_score\n",
    "\n",
    "    # Your code here\n",
    "    r2_train, r2_test = [], []\n",
    "    \n",
    "    for degree in range(0,10):\n",
    "        poly = PolynomialFeatures(degree = degree)\n",
    "        X_train_poly = poly.fit_transform(X_train.reshape(11, 1))\n",
    "        X_test_poly = poly.fit_transform(X_test.reshape(4, 1))\n",
    "        linreg = LinearRegression().fit(X_train_poly, y_train)\n",
    "        r2_train.append(linreg.score(X_train_poly, y_train))\n",
    "        r2_test.append(linreg.score(X_test_poly, y_test))\n",
    "\n",
    "    return (r2_train, r2_test)\n",
    "\n",
    "answer_two()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 3\n",
    "\n",
    "Based on the $R^2$ scores from question 2 (degree levels 0 through 9), what degree level corresponds to a model that is underfitting? What degree level corresponds to a model that is overfitting? What choice of degree level would provide a model with good generalization performance on this dataset? \n",
    "\n",
    "Hint: Try plotting the $R^2$ scores from question 2 to visualize the relationship between degree level and $R^2$. Remember to comment out the import matplotlib line before submission.\n",
    "\n",
    "*This function should return one tuple with the degree values in this order: `(Underfitting, Overfitting, Good_Generalization)`. There might be multiple correct solutions, however, you only need to return one possible solution, for example, (1,2,3).* "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 9, 6)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_three():\n",
    "    \n",
    "    # Your code here\n",
    "    r2_score = answer_two()\n",
    "    df = pd.DataFrame({'training score' : r2_score[0], 'test score' : r2_score[1]}, index = range(0,10))\n",
    "    df['diff'] = df['training score'] - df['test score']\n",
    "    \n",
    "    df = df.sort_values(['training score'])\n",
    "    Underfitting = df.index[0]\n",
    "    \n",
    "    df = df.sort_values(['diff'])\n",
    "    Overfitting = df.index[9]\n",
    "    Good_Generalization = df.index[0]\n",
    "    \n",
    "    return (Underfitting, Overfitting, Good_Generalization)\n",
    "\n",
    "answer_three()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 4\n",
    "\n",
    "Training models on high degree polynomial features can result in overly complex models that overfit, so we often use regularized versions of the model to constrain model complexity, as we saw with Ridge and Lasso linear regression.\n",
    "\n",
    "For this question, train two models: a non-regularized LinearRegression model (default parameters) and a regularized Lasso Regression model (with parameters `alpha=0.01`, `max_iter=10000`) both on polynomial features of degree 12. Return the $R^2$ score for both the LinearRegression and Lasso model's test sets.\n",
    "\n",
    "*This function should return one tuple `(LinearRegression_R2_test_score, Lasso_R2_test_score)`*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/sklearn/linear_model/coordinate_descent.py:484: ConvergenceWarning: Objective did not converge. You might want to increase the number of iterations. Fitting data with very small alpha may cause precision problems.\n",
      "  ConvergenceWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(-4.3120017974975458, 0.8406625614750235)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_four():\n",
    "    from sklearn.preprocessing import PolynomialFeatures\n",
    "    from sklearn.linear_model import Lasso, LinearRegression\n",
    "    from sklearn.metrics.regression import r2_score\n",
    "\n",
    "    # Your code here\n",
    "    poly = PolynomialFeatures(degree = 12)\n",
    "    X_train_poly = poly.fit_transform(X_train.reshape(11, 1))\n",
    "    X_test_poly = poly.fit_transform(y_test.reshape(4, 1))\n",
    "    linreg = LinearRegression().fit(X_train_poly, y_train)\n",
    "    lassoreg = Lasso(alpha=0.01, max_iter=10000).fit(X_train_poly, y_train)\n",
    "    LinearRegression_R2_test_score = linreg.score(X_test_poly, y_test)\n",
    "    Lasso_R2_test_score = lassoreg.score(X_test_poly, y_test)\n",
    "\n",
    "    return (LinearRegression_R2_test_score, Lasso_R2_test_score)\n",
    "\n",
    "answer_four()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2 - Classification\n",
    "\n",
    "Here's an application of machine learning that could save your life! For this section of the assignment we will be working with the [UCI Mushroom Data Set](http://archive.ics.uci.edu/ml/datasets/Mushroom?ref=datanews.io) stored in `readonly/mushrooms.csv`. The data will be used to train a model to predict whether or not a mushroom is poisonous. The following attributes are provided:\n",
    "\n",
    "*Attribute Information:*\n",
    "\n",
    "1. cap-shape: bell=b, conical=c, convex=x, flat=f, knobbed=k, sunken=s \n",
    "2. cap-surface: fibrous=f, grooves=g, scaly=y, smooth=s \n",
    "3. cap-color: brown=n, buff=b, cinnamon=c, gray=g, green=r, pink=p, purple=u, red=e, white=w, yellow=y \n",
    "4. bruises?: bruises=t, no=f \n",
    "5. odor: almond=a, anise=l, creosote=c, fishy=y, foul=f, musty=m, none=n, pungent=p, spicy=s \n",
    "6. gill-attachment: attached=a, descending=d, free=f, notched=n \n",
    "7. gill-spacing: close=c, crowded=w, distant=d \n",
    "8. gill-size: broad=b, narrow=n \n",
    "9. gill-color: black=k, brown=n, buff=b, chocolate=h, gray=g, green=r, orange=o, pink=p, purple=u, red=e, white=w, yellow=y \n",
    "10. stalk-shape: enlarging=e, tapering=t \n",
    "11. stalk-root: bulbous=b, club=c, cup=u, equal=e, rhizomorphs=z, rooted=r, missing=? \n",
    "12. stalk-surface-above-ring: fibrous=f, scaly=y, silky=k, smooth=s \n",
    "13. stalk-surface-below-ring: fibrous=f, scaly=y, silky=k, smooth=s \n",
    "14. stalk-color-above-ring: brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y \n",
    "15. stalk-color-below-ring: brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y \n",
    "16. veil-type: partial=p, universal=u \n",
    "17. veil-color: brown=n, orange=o, white=w, yellow=y \n",
    "18. ring-number: none=n, one=o, two=t \n",
    "19. ring-type: cobwebby=c, evanescent=e, flaring=f, large=l, none=n, pendant=p, sheathing=s, zone=z \n",
    "20. spore-print-color: black=k, brown=n, buff=b, chocolate=h, green=r, orange=o, purple=u, white=w, yellow=y \n",
    "21. population: abundant=a, clustered=c, numerous=n, scattered=s, several=v, solitary=y \n",
    "22. habitat: grasses=g, leaves=l, meadows=m, paths=p, urban=u, waste=w, woods=d\n",
    "\n",
    "<br>\n",
    "\n",
    "The data in the mushrooms dataset is currently encoded with strings. These values will need to be encoded to numeric to work with sklearn. We'll use pd.get_dummies to convert the categorical variables into indicator variables. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "mush_df = pd.read_csv('mushrooms.csv')\n",
    "mush_df2 = pd.get_dummies(mush_df)\n",
    "\n",
    "X_mush = mush_df2.iloc[:,2:]\n",
    "y_mush = mush_df2.iloc[:,1]\n",
    "\n",
    "# use the variables X_train2, y_train2 for Question 5\n",
    "X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)\n",
    "\n",
    "# For performance reasons in Questions 6 and 7, we will create a smaller version of the\n",
    "# entire mushroom dataset for use in those questions.  For simplicity we'll just re-use\n",
    "# the 25% test split created above as the representative subset.\n",
    "#\n",
    "# Use the variables X_subset, y_subset for Questions 6 and 7.\n",
    "X_subset = X_test2\n",
    "y_subset = y_test2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 5\n",
    "\n",
    "Using `X_train2` and `y_train2` from the preceeding cell, train a DecisionTreeClassifier with default parameters and random_state=0. What are the 5 most important features found by the decision tree?\n",
    "\n",
    "As a reminder, the feature names are available in the `X_train2.columns` property, and the order of the features in `X_train2.columns` matches the order of the feature importance values in the classifier's `feature_importances_` property. \n",
    "\n",
    "*This function should return a list of length 5 containing the feature names in descending order of importance.*\n",
    "\n",
    "*Note: remember that you also need to set random_state in the DecisionTreeClassifier.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['odor_n', 'stalk-root_c', 'stalk-root_r', 'spore-print-color_r', 'odor_l']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_five():\n",
    "    from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "    # Your code here\n",
    "    clf = DecisionTreeClassifier(random_state = 0).fit(X_train2, y_train2)\n",
    "    \n",
    "    df = pd.DataFrame({'Feature Names' : X_train2.columns, 'Importance' : clf.feature_importances_})\n",
    "    \n",
    "    return list(df.sort_values(['Importance'], ascending = False)['Feature Names'].head(5))\n",
    "\n",
    "answer_five()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 6\n",
    "\n",
    "For this question, we're going to use the `validation_curve` function in `sklearn.model_selection` to determine training and test scores for a Support Vector Classifier (`SVC`) with varying parameter values.  Recall that the validation_curve function, in addition to taking an initialized unfitted classifier object, takes a dataset as input and does its own internal train-test splits to compute results.\n",
    "\n",
    "**Because creating a validation curve requires fitting multiple models, for performance reasons this question will use just a subset of the original mushroom dataset: please use the variables X_subset and y_subset as input to the validation curve function (instead of X_mush and y_mush) to reduce computation time.**\n",
    "\n",
    "The initialized unfitted classifier object we'll be using is a Support Vector Classifier with radial basis kernel.  So your first step is to create an `SVC` object with default parameters (i.e. `kernel='rbf', C=1`) and `random_state=0`. Recall that the kernel width of the RBF kernel is controlled using the `gamma` parameter.  \n",
    "\n",
    "With this classifier, and the dataset in X_subset, y_subset, explore the effect of `gamma` on classifier accuracy by using the `validation_curve` function to find the training and test scores for 6 values of `gamma` from `0.0001` to `10` (i.e. `np.logspace(-4,1,6)`). Recall that you can specify what scoring metric you want validation_curve to use by setting the \"scoring\" parameter.  In this case, we want to use \"accuracy\" as the scoring metric.\n",
    "\n",
    "For each level of `gamma`, `validation_curve` will fit 3 models on different subsets of the data, returning two 6x3 (6 levels of gamma x 3 fits per level) arrays of the scores for the training and test sets.\n",
    "\n",
    "Find the mean score across the three models for each level of `gamma` for both arrays, creating two arrays of length 6, and return a tuple with the two arrays.\n",
    "\n",
    "e.g.\n",
    "\n",
    "if one of your array of scores is\n",
    "\n",
    "    array([[ 0.5,  0.4,  0.6],\n",
    "           [ 0.7,  0.8,  0.7],\n",
    "           [ 0.9,  0.8,  0.8],\n",
    "           [ 0.8,  0.7,  0.8],\n",
    "           [ 0.7,  0.6,  0.6],\n",
    "           [ 0.4,  0.6,  0.5]])\n",
    "       \n",
    "it should then become\n",
    "\n",
    "    array([ 0.5,  0.73333333,  0.83333333,  0.76666667,  0.63333333, 0.5])\n",
    "\n",
    "*This function should return one tuple of numpy arrays `(training_scores, test_scores)` where each array in the tuple has shape `(6,)`.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 0.56647847,  0.93155951,  0.99039881,  1.        ,  1.        ,  1.        ]),\n",
       " array([ 0.56768547,  0.92959558,  0.98965952,  1.        ,  0.99507994,\n",
       "         0.52240279]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_six():\n",
    "    from sklearn.svm import SVC\n",
    "    from sklearn.model_selection import validation_curve\n",
    "\n",
    "    # Your code here\n",
    "    training_scores, test_scores = validation_curve(SVC(kernel='rbf', C=1, random_state = 0), X_subset, y_subset,\n",
    "                                            param_name='gamma',\n",
    "                                            param_range=np.logspace(-4, 1, 6), cv=3)\n",
    "\n",
    "    return (training_scores.mean(axis = 1), test_scores.mean(axis = 1))\n",
    "    # Alternatives: \n",
    "#     return np.array(list(map(np.mean, training_scores))), np.array(list(map(np.mean, test_scores)))\n",
    "\n",
    "answer_six()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Question 7\n",
    "\n",
    "Based on the scores from question 6, what gamma value corresponds to a model that is underfitting (and has the worst test set accuracy)? What gamma value corresponds to a model that is overfitting (and has the worst test set accuracy)? What choice of gamma would be the best choice for a model with good generalization performance on this dataset (high accuracy on both training and test set)? \n",
    "\n",
    "Hint: Try plotting the scores from question 6 to visualize the relationship between gamma and accuracy. Remember to comment out the import matplotlib line before submission.\n",
    "\n",
    "*This function should return one tuple with the degree values in this order: `(Underfitting, Overfitting, Good_Generalization)` Please note there is only one correct solution.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.0001, 10.0, 0.10000000000000001)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def answer_seven():\n",
    "    \n",
    "    # Your code here\n",
    "    train_score, test_score = answer_six()\n",
    "    df = pd.DataFrame({'Train_Score' : train_score, 'Test_Score' : test_score, 'gamma' : np.logspace(-4, 1, 6)})\n",
    "    df['diff'] = df['Train_Score'] - df['Test_Score']\n",
    "    \n",
    "    Underfitting = df.sort_values(['Train_Score']).iloc[0]['gamma']\n",
    "    \n",
    "    Overfitting = df.sort_values(['diff']).iloc[5]['gamma']\n",
    "    \n",
    "    Good_Generalization = df.sort_values(['diff']).iloc[1]['gamma']\n",
    "    \n",
    "    return (Underfitting, Overfitting, Good_Generalization)\n",
    "\n",
    "answer_seven()"
   ]
  }
 ],
 "metadata": {
  "coursera": {
   "course_slug": "python-machine-learning",
   "graded_item_id": "eWYHL",
   "launcher_item_id": "BAqef",
   "part_id": "fXXRp"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
